# Copyright 2022 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content", "iree_runtime_cc_library")
load("//build_tools/bazel:iree_bitcode_library.bzl", "iree_bitcode_library", "iree_link_bitcode")
load("//build_tools/embed_data:build_defs.bzl", "iree_c_embed_data")

# Package-wide defaults: targets are publicly visible unless they set their own
# `visibility`, and the `layering_check` feature is requested for this package.
package(
    default_visibility = ["//visibility:public"],
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

# Header-only target exposing just exported_bits.h. `:ukernel` lists it in its
# deps.
iree_runtime_cc_library(
    name = "exported_bits",
    hdrs = ["exported_bits.h"],
)

# Headers shared by several targets in this file: the
# `internal_headers_filegroup` filegroup, the `internal_headers` and `fallback`
# libraries (as `hdrs`), and `ukernel` (appended to `srcs`).
# Note that exported_bits.h is listed here and is also the sole header of the
# `:exported_bits` target.
internal_headers = [
    "common.h",
    "exported_bits.h",
    "mmt4d.h",
    "mmt4d_internal.h",
    "pack.h",
    "pack_internal.h",
    "query_tile_sizes.h",
    "query_tile_sizes_internal.h",
    "unpack.h",
    "unpack_internal.h",
]

# Filegroup used in Bazel only (only Bazel enforces header dependencies).
# It is passed as `internal_hdrs` to the `ukernel_bitcode_generic_*` bitcode
# libraries defined in this file.
filegroup(
    name = "internal_headers_filegroup",
    srcs = internal_headers,
)

# Header-only library (no `srcs`) that exposes the internal headers to
# subpackages of this package. It depends on the iree/base core headers.
iree_runtime_cc_library(
    name = "internal_headers",
    hdrs = internal_headers,
    visibility = [":__subpackages__"],
    deps = [
        "//runtime/src/iree/base:core_headers",
    ],
)

# Library built from fallback.c, visible to subpackages only. The same
# fallback.c is also compiled into the generic bitcode for archs that are not
# listed in `bitcode_specific_archs`.
# NOTE(review): this target re-exports the same headers as `:internal_headers`
# and declares no `deps`, whereas `:internal_headers` depends on
# //runtime/src/iree/base:core_headers. Confirm whether depending on
# `:internal_headers` (instead of repeating `hdrs`) is what is intended here.
iree_runtime_cc_library(
    name = "fallback",
    srcs = ["fallback.c"],
    hdrs = internal_headers,
    visibility = [":__subpackages__"],
)

# Entry points.
# api.h is the only header listed in `hdrs`; the internal headers are appended
# to `srcs` instead.
# NOTE(review): query_tile_sizes.c is compiled here but is not in the `srcs` of
# the `ukernel_bitcode_generic_*` libraries in this file — confirm that the
# difference is intentional.
iree_runtime_cc_library(
    name = "ukernel",
    srcs = [
        "mmt4d.c",
        "mmt4d_tile_generic.c",
        "pack.c",
        "pack_tile.c",
        "query_tile_sizes.c",
        "unpack.c",
        "unpack_tile.c",
    ] + internal_headers,
    hdrs = ["api.h"],
    deps = [
        ":exported_bits",
        "//runtime/src/iree/base:core_headers",
        # Code from the arch/ subpackage.
        "//runtime/src/iree/builtins/ukernel/arch:ukernel_arch",
    ],
)

#===------------------------------------------------------------------------===#
# UKernel bitcode files
#===------------------------------------------------------------------------===#

# Supplies the opening line of a CMake `if()` block conditioned on
# IREE_BUILD_COMPILER and IREE_TARGET_BACKEND_LLVM_CPU. The matching `endif()`
# is supplied by the second iree_cmake_extra_content call at the end of this
# section, so the two calls must stay paired.
# NOTE(review): presumably `inline = True` places the content at this position
# in the generated CMake — confirm against the macro definition.
iree_cmake_extra_content(
    content = """
if(IREE_BUILD_COMPILER AND IREE_TARGET_BACKEND_LLVM_CPU)
""",
    inline = True,
)

# Enumerate all archs for which to generate the build-system logic and targets
# to build ukernel bitcode. This doesn't necessarily imply actually building
# bitcode for all these architectures, as we can still have e.g. a CMake option
# to control which archs to build bitcode for, and even without that, we still
# have CMake logic omitting bitcode for archs which we are not targeting, so that
# disabling a target arch also disables the corresponding bitcode.
# As of early 2024, generic bitcode for an additional arch is ~ 20 kB, while
# bitcode for archs for which we have dedicated optimized ukernel code is ~ 100 kB.
#
# Each entry produces one `ukernel_bitcode_generic_<arch>` library and one
# linked `ukernel_bitcode_<arch>` target below, and contributes one
# `ukernel_bitcode_<arch>.bc` file to `embed_ukernel_bitcode`.
bitcode_generic_archs = [
    "x86_64",
    "arm_64",
    "arm_32",
    "riscv_64",
    "riscv_32",
]

# Enumerate all archs for which we have arch-specific dedicated ukernel code,
# in an arch/${arch} subdirectory.
#
# This list is only consulted through membership tests while iterating over
# `bitcode_generic_archs`, so an entry that is not also in
# `bitcode_generic_archs` has no effect in this file. For archs listed here,
# the generic bitcode omits fallback.c and the linked bitcode additionally
# pulls in `ukernel_bitcode_arch_<arch>.bc` from the arch/<arch> subpackage.
bitcode_specific_archs = [
    "x86_64",
    "arm_64",
    "riscv_64",
]

# One generic (arch-independent sources) bitcode library per entry of
# `bitcode_generic_archs`. fallback.c is appended, as the last source, only for
# archs that have no dedicated code, i.e. that are absent from
# `bitcode_specific_archs`.
[iree_bitcode_library(
    name = "ukernel_bitcode_generic_%s" % target_arch,
    srcs = [
        "mmt4d.c",
        "mmt4d_tile_generic.c",
        "pack.c",
        "pack_tile.c",
        "unpack.c",
        "unpack_tile.c",
    ] + (["fallback.c"] if target_arch not in bitcode_specific_archs else []),
    arch = target_arch,
    internal_hdrs = [
        ":internal_headers_filegroup",
        "//runtime/src/iree/schemas:cpu_data_headers_filegroup",
    ],
) for target_arch in bitcode_generic_archs]

# One linked bitcode target per entry of `bitcode_generic_archs`: always the
# generic bitcode for that arch, followed by the arch-specific bitcode from the
# arch/<arch> subpackage when the arch is in `bitcode_specific_archs`.
[iree_link_bitcode(
    name = "ukernel_bitcode_%s" % target_arch,
    bitcode_files = [
        ":ukernel_bitcode_generic_%s.bc" % target_arch,
    ] + ([] if target_arch not in bitcode_specific_archs else [
        "arch/%s:ukernel_bitcode_arch_%s.bc" % (target_arch, target_arch),
    ]),
) for target_arch in bitcode_generic_archs]

# Embeds the linked per-arch bitcode files (one `ukernel_bitcode_<arch>.bc` per
# entry of `bitcode_generic_archs`) as C data, generating ukernel_bitcode.c and
# ukernel_bitcode.h with the identifier `iree_ukernel_bitcode`.
iree_c_embed_data(
    name = "embed_ukernel_bitcode",
    srcs = [
        ":ukernel_bitcode_%s.bc" % arch
        for arch in bitcode_generic_archs
    ],
    c_file_output = "ukernel_bitcode.c",
    # NOTE(review): presumably strips directory components from the embedded
    # file names — confirm against the iree_c_embed_data definition.
    flatten = True,
    h_file_output = "ukernel_bitcode.h",
    identifier = "iree_ukernel_bitcode",
    deps = [
        "//runtime/src:runtime_defines",
    ],
)

# Supplies the `endif()` that closes the CMake
# `if(IREE_BUILD_COMPILER AND IREE_TARGET_BACKEND_LLVM_CPU)` opened by the
# iree_cmake_extra_content call at the start of the bitcode section.
iree_cmake_extra_content(
    content = """
endif()  # IREE_BUILD_COMPILER AND IREE_TARGET_BACKEND_LLVM_CPU
""",
    inline = True,
)
